from src.simulator import Simulator
import os
from src.discretizer import Discretizer
from src.abstraction import CentroidAbstraction, InterfaceAbstraction
from src.data_structures import Option, SingleValueSampler, State, Interface

from scipy.spatial.kdtree import KDTree
import pickle
import numpy as np
import time

class ProblemEvaluator(object):
    """Replay stored policies on a set of test problems and record the outcome.

    The constructor reads paths and discretisation settings from
    ``problem_config`` and obtains an abstraction, either from a pickled cache
    file in the run's policy folder or by building it from the prediction and
    environment-mask ``.npy`` files.  ``evaluate`` then executes the stored
    policies of every test problem in a ``Simulator`` and pickles a summary.
    """

    # Factor applied in ``evaluate`` to every component except the last one of
    # each (truncated) init/goal location.
    _LOCATION_SCALE = 3
    # Number of simulator runs per problem; run ``j`` uses ``seeds[j]``.
    _ITERATIONS_PER_PROBLEM = 1

    def __init__(self, problem_config, robot_config, mp_robot_config, runid, seeds, abstraction_type):
        """Store the configuration and load or build the abstraction.

        Args:
            problem_config: Mapping with the keys read below ("preds_path",
                "env_name", "goal_tolerance", "env_mask_path",
                "policy_folder", "results_folder", "n_xy_bins", "n_dof_bins",
                "timeout_per_option").  Its "policy_folder" entry is
                overwritten with the run-specific folder.
            robot_config: Robot description handed to the main ``Discretizer``.
            mp_robot_config: Robot description handed to the second
                ``Discretizer`` (``self.mp_discretizer``).
            runid: Run identifier; used as a path component.
            seeds: Indexable collection of seeds, one per simulator run.
            abstraction_type: ``"centroid"`` or ``"interface"``.

        Raises:
            ValueError: If no cached abstraction exists and
                ``abstraction_type`` is not one of the supported choices.
        """
        self.runid = runid
        self.seeds = seeds
        self.problem_config = problem_config
        # self.problem_config['simulator_gui'] = True
        self.robot = robot_config
        self.mp_robot = mp_robot_config
        self.preds_path = os.path.join(self.problem_config["preds_path"], self.problem_config['env_name'] + ".npy")
        self.goal_tolerance = float(self.problem_config["goal_tolerance"])
        self.env_path = os.path.join(self.problem_config["env_mask_path"], self.problem_config['env_name'] + ".npy")
        self.policy_folder = os.path.join(self.problem_config['policy_folder'], runid)
        self.results_folder = self.problem_config["results_folder"]

        # Temp fix to save pytorch models in transit.
        # NOTE(review): this mutates the caller's config, so building a second
        # evaluator from the same dict appends the run id a second time.
        self.problem_config['policy_folder'] = self.policy_folder

        self.n_xy_bins = int(self.problem_config["n_xy_bins"])
        self.n_dof_bins = int(self.problem_config["n_dof_bins"])
        self.timeout_per_option = int(self.problem_config["timeout_per_option"])
        self.discretizer = Discretizer(self.robot, self.n_xy_bins, self.n_dof_bins)
        self.mp_discretizer = Discretizer(self.mp_robot, self.n_xy_bins, self.n_dof_bins)

        # Only populated when the abstraction is built from scratch; defined
        # up front so the attributes exist on both code paths.
        self.nn_preds = None
        self.env_mask = None

        abs_fname = "{}_{}_abs.p".format(self.problem_config['env_name'], abstraction_type)
        abs_path = os.path.join(self.policy_folder, abs_fname)
        if os.path.exists(abs_path):
            # NOTE: pickle.load can execute arbitrary code; the cache file
            # must come from a trusted source.
            with open(abs_path, "rb") as f:
                self.abstraction = pickle.load(f)
            return

        abstraction_classes = {
            'centroid': CentroidAbstraction,
            'interface': InterfaceAbstraction,
        }
        # Reject an unknown type before loading the (possibly large) arrays.
        if abstraction_type not in abstraction_classes:
            raise ValueError("Undefined abstraction type {}. Choices: {}".format(abstraction_type, ", ".join(abstraction_classes)))
        self.nn_preds = self.load_pd()
        self.env_mask = self.load_env_mask()
        self.abstraction = abstraction_classes[abstraction_type](self.nn_preds, self.discretizer, self.env_mask)

    def load_pd(self):
        """Load the array stored at ``self.preds_path`` with unit axes removed."""
        return np.squeeze(np.load(self.preds_path))

    def load_env_mask(self):
        """Load the array at ``self.env_path``, squeeze it, and keep index 0 of the third axis."""
        return np.squeeze(np.load(self.env_path))[:, :, 0]

    def evaluate(self):
        """Execute the stored policies of each test problem and save the results.

        Reads init/goal locations from
        ``<experiments_path>/<robot>_<env>.pickle`` and, for problem ``i``,
        the policies from ``<policy_logs>/<robot>/<env>/<runid>/p<i+1>.pickle``.
        Problems without a policy file are skipped.  The summary is printed
        and pickled to ``<results_folder>/<runid>_<env>_<robot>``; the results
        folder is created when it does not exist.

        Returns:
            dict: Maps each evaluated problem index to a dict with the keys
            "action_counts", "rewards", "pass_fail" (one entry per run) and
            "succ" (passed runs divided by total runs).
        """
        robot_name = self.problem_config['robot']['name']
        env_name = self.problem_config['env_name']

        experiment_fname = '{}_{}.pickle'.format(robot_name, env_name)
        experiment = os.path.join(self.problem_config['experiments_path'], experiment_fname)
        # NOTE: pickle.load can execute arbitrary code; experiment and policy
        # files must come from a trusted source.
        with open(experiment, 'rb') as f:
            load_locations = pickle.load(f)
        self.action_log_dir = os.path.join(self.problem_config['action_logs'], robot_name, env_name, self.runid)
        self.policy_log_folder = os.path.join(self.problem_config['policy_logs'], robot_name, env_name, self.runid)

        results = {}
        for i in range(self.problem_config["num_test_problems"]):
            # Guard first: nothing needs to be prepared for a problem that
            # has no recorded policies.
            policy_file = os.path.join(self.policy_log_folder, "p{}.pickle".format(i + 1))
            if not os.path.isfile(policy_file):
                continue
            with open(policy_file, "rb") as f:
                policies = pickle.load(f)

            init = load_locations["init"][i][:3]
            goal = load_locations["goal"][i][:3]
            # Scale every component but the last one, in place.
            for k in range(len(init) - 1):
                init[k] = init[k] * self._LOCATION_SCALE
                goal[k] = goal[k] * self._LOCATION_SCALE
            init_sampler = SingleValueSampler(init)
            term_sampler = SingleValueSampler(goal)

            action_counts = []
            pass_fail = []
            rewards = []
            count = 0
            for j in range(self._ITERATIONS_PER_PROBLEM):
                print("Running problem #{}: iteration #{}".format(i + 1, j + 1), flush=True)
                action_count = 0
                simulator = Simulator(init_sampler, self.abstraction, self.problem_config, term_sampler, seed=self.seeds[j])
                for policy_path, option_guide, option_switch, eval_func in policies:
                    # Each stored policy carries the evaluation function the
                    # abstraction must use while that policy is executed.
                    self.abstraction.eval = eval_func
                    print(policy_path, flush=True)
                    simulator.set_init_sample_and_eval_func(init_sampler, self.abstraction, term_sampler)
                    action_count += simulator.execute_policy(policy_path, option_guide, option_switch)
                current_ll_config = simulator.get_ll_config()
                action_counts.append(action_count)
                if self.abstraction.evaluate_function(current_ll_config) == 1:
                    print("Pass", flush=True)
                    count += 1
                    pass_fail.append(1)
                    rewards.append(-1 * (action_count - 1) + 1000)
                else:
                    print("Fail", flush=True)
                    rewards.append(-1 * action_count)
                    pass_fail.append(0)
                # NOTE(review): fixed pause before dropping the simulator;
                # presumably lets its backend settle - confirm it is needed.
                time.sleep(3)
                del simulator
            results[i] = {
                "action_counts": action_counts,
                "rewards": rewards,
                "pass_fail": pass_fail,
                "succ": count / self._ITERATIONS_PER_PROBLEM,
            }

        print(results, flush=True)
        # Make sure a missing output directory cannot discard the results of
        # a finished evaluation.
        if self.results_folder:
            os.makedirs(self.results_folder, exist_ok=True)
        results_fname = "{}_{}_{}".format(self.runid, env_name, robot_name)
        with open(os.path.join(self.results_folder, results_fname), "wb") as f:
            pickle.dump(results, f)
        return results
            

            
